{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "vietnamese-skill",
   "metadata": {},
   "source": [
    "# Import Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "oriented-applicant",
   "metadata": {},
   "outputs": [],
   "source": [
    "from fairlearn.metrics import (\n",
    "    MetricFrame,\n",
    "    count\n",
    ")\n",
    "import numpy as np\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "legislative-peripheral",
   "metadata": {},
   "outputs": [],
   "source": [
    "from jiwer import wer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "trying-plane",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa3da0e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "import warnings\n",
    "\n",
    "warnings.simplefilter(action=\"ignore\")\n",
    "logger = logging.getLogger()\n",
    "logger.setLevel(logging.CRITICAL)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "union-array",
   "metadata": {},
   "source": [
    "# Speech-to-Text Fairness Assessment"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "impressive-destination",
   "metadata": {},
   "source": [
    "In this notebook, we walk through a scenario of assessing a *speech-to-text* service for fairness-related disparities. For this fairness assessment, we consider several sensitive features: the speaker's `native_language`, their `sex`, and the `country` where they are located."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "prospective-demonstration",
   "metadata": {},
   "source": [
    "For this audit, we work with a CSV file, `stt_testing_data.csv`, that contains 2138 speech samples. Each row in the dataset represents a person reading a particular passage in English. An automated speech-to-text system generates a transcription from the audio of each person reading the passage."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "killing-filter",
   "metadata": {},
   "source": [
    "If you wish to run this notebook with your own speech data, you can use the provided `run_stt.py` script to query the Microsoft Cognitive Service Speech API. You can also run this notebook with a dataset generated by another speech-to-text system."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "orange-quarter",
   "metadata": {},
   "outputs": [],
   "source": [
    "import zipfile\n",
    "from raiutils.dataset import fetch_dataset\n",
    "outdirname = 'responsibleai.12.28.21'\n",
    "zipfilename = outdirname + '.zip'\n",
    "\n",
    "fetch_dataset('https://publictestdatasets.blob.core.windows.net/data/' + zipfilename, zipfilename)\n",
    "\n",
    "with zipfile.ZipFile(zipfilename, 'r') as unzip:\n",
    "    unzip.extractall('.')\n",
    "stt_results_csv = \"stt_testing_data.csv\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "818999a5",
   "metadata": {},
   "source": [
    "In this dataset, `ground_truth_text` represents the English passage the participant was asked to read. `predicted_text` represents the transcription produced by the automated service. In an ideal scenario, the `ground_truth_text` would represent the transcription produced by a human transcriber, so we could compare the output of the automated transcription to the one produced by the human transcriber. We will also look at demographic features, such as `sex` of the participant and the `country` where the participant is located."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "swedish-hotel",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_csv(\n",
    "    stt_results_csv,\n",
    "    usecols=[\"age\", \"native_language\", \"sex\", \"country\", \"ground_truth_text\", \"predicted_text\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c98e3b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "df.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "radical-first",
   "metadata": {},
   "source": [
    "The goal of a fairness assessment is to *understand which groups of people may be disproportionately negatively impacted by an AI system, and in which ways*.\n",
    "\n",
    "For our fairness assessment, we perform the following tasks:\n",
    "\n",
    "1. Identify harms and which groups may be harmed.\n",
    "\n",
    "2. Define fairness metrics to quantify harms.\n",
    "\n",
    "3. Compare our quantified harms across the relevant groups."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "alike-float",
   "metadata": {},
   "source": [
    "## 1.) Identify harms and groups who may be harmed"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "electrical-olympus",
   "metadata": {},
   "source": [
    "The first step of the fairness assessment is identifying the types of fairness-related harms we expect users of the system to experience. Based on the harms taxonomy in the [Fairlearn User Guide](https://fairlearn.org/v0.7.0/user_guide/fairness_in_machine_learning.html#types-of-harms), we expect that our *speech-to-text* system may produce *quality-of-service* harms for its users.\n",
    "\n",
    "*Quality-of-service* harms occur when an AI system does not work as well for one person as it does for others, even when no resources or opportunities are withheld."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "abstract-growing",
   "metadata": {},
   "source": [
    "Several studies have demonstrated that speech-to-text systems achieve different levels of performance depending on the speaker's gender and language dialect. In this assessment, we explore how the performance of our speech-to-text system differs based on language dialect (proxied by `country`) and `sex` for speakers in three English-speaking countries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "regular-recognition",
   "metadata": {},
   "outputs": [],
   "source": [
    "sensitive_country = [\"country\"]\n",
    "sensitive_country_sex = [\"country\", \"sex\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "straight-mailing",
   "metadata": {},
   "outputs": [],
   "source": [
    "countries = [\"usa\", \"uk\", \"canada\"]\n",
    "filtered_df = df.query(f\"country in {countries} and native_language == 'english'\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "otherwise-making",
   "metadata": {},
   "outputs": [],
   "source": [
    "filtered_df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "collectible-negotiation",
   "metadata": {},
   "source": [
    "One challenge for our fairness assessment is the small sample size of some groups. Our filtered dataset consists primarily of English speakers in the USA, so we expect higher uncertainty for our metrics on speakers from the other two countries. The smaller sample sizes for *UK* and *Canadian* speakers mean we may not be able to find significant differences once we also account for `sex`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "actual-split",
   "metadata": {},
   "outputs": [],
   "source": [
    "display(filtered_df.groupby([\"country\"])[\"sex\"].count())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "unique-newfoundland",
   "metadata": {},
   "source": [
    "## 2.) Define fairness metrics to quantify harms"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "amber-primary",
   "metadata": {},
   "source": [
    "To measure differences in performance, we use the `word_error_rate` (WER). The `word_error_rate` is the number of word-level errors in the transcription (substitutions, deletions, and insertions) divided by the number of words in the ground truth text. A higher `word_error_rate` means the system performs worse for a particular group."
   ]
  },
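  {
   "cell_type": "markdown",
   "id": "simplified-advertising",
   "metadata": {},
   "source": [
    "Note that this compares the system output with the passage text rather than with a human transcription. If what a speaker actually said differs from the ground truth text, that deviation is also counted as a transcription error."
   ]
  },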
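  {
   "cell_type": "markdown",
   "id": "wer-sketch-text",
   "metadata": {},
   "source": [
    "To make the definition concrete, the next cell is a minimal sketch of a word error rate computed from a word-level edit distance. The helper names `word_level_edit_distance` and `simple_wer` and the toy sentences are illustrative assumptions and are not part of the dataset or of `jiwer`. The analysis itself keeps using `jiwer`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "wer-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "def word_level_edit_distance(reference_words, hypothesis_words):\n",
    "    # Dynamic-programming edit distance over words rather than characters\n",
    "    previous = list(range(len(hypothesis_words) + 1))\n",
    "    for i, ref_word in enumerate(reference_words, start=1):\n",
    "        current = [i]\n",
    "        for j, hyp_word in enumerate(hypothesis_words, start=1):\n",
    "            substitution = previous[j - 1] + (ref_word != hyp_word)\n",
    "            deletion = previous[j] + 1\n",
    "            insertion = current[j - 1] + 1\n",
    "            current.append(min(substitution, deletion, insertion))\n",
    "        previous = current\n",
    "    return previous[-1]\n",
    "\n",
    "\n",
    "def simple_wer(reference, hypothesis):\n",
    "    # Number of word-level errors divided by the number of reference words\n",
    "    reference_words = reference.split()\n",
    "    hypothesis_words = hypothesis.split()\n",
    "    return word_level_edit_distance(reference_words, hypothesis_words) / len(reference_words)\n",
    "\n",
    "\n",
    "# Toy sentences (illustrative, not dataset rows): one substitution and one\n",
    "# deletion against a six-word reference give a WER of 2 / 6\n",
    "toy_reference = 'please call stella ask her to'\n",
    "toy_hypothesis = 'please call stellar ask to'\n",
    "print(simple_wer(toy_reference, toy_hypothesis))"
   ]
  },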
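  {
   "cell_type": "code",
   "execution_count": null,
   "id": "rational-checkout",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed baseline WER against which the disparity metrics are measured\n",
    "DISPARITY_BASE = 0.5\n",
    "\n",
    "def word_error_rate(y_true, y_pred):\n",
    "    # Pass one string per sample so jiwer aligns each reference with its own\n",
    "    # transcription, instead of the printed form of the whole column\n",
    "    references = [str(text) for text in y_true]\n",
    "    hypotheses = [str(text) for text in y_pred]\n",
    "    return wer(references, hypotheses)\n",
    "\n",
    "def wer_abs_disparity(y_true, y_pred, disparity=DISPARITY_BASE):\n",
    "    # Absolute gap between the group's WER and the baseline\n",
    "    return word_error_rate(y_true, y_pred) - disparity\n",
    "\n",
    "def wer_rel_disparity(y_true, y_pred, disparity=DISPARITY_BASE):\n",
    "    # Gap expressed as a fraction of the baseline\n",
    "    return wer_abs_disparity(y_true, y_pred, disparity) / disparity"
   ]
  },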
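  {
   "cell_type": "markdown",
   "id": "buried-particular",
   "metadata": {},
   "source": [
    "The two disparity metrics express the WER relative to a fixed baseline (`DISPARITY_BASE`), once as an absolute gap and once as a fraction of the baseline. It might be better to compute the maximal difference between groups instead."
   ]
  },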
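  {
   "cell_type": "markdown",
   "id": "max-diff-text",
   "metadata": {},
   "source": [
    "As a rough illustration of the maximal between-group difference, the next cell subtracts the smallest group value from the largest one using plain pandas. The group names and WER values are hypothetical placeholders rather than results from this dataset. This is the same quantity that `MetricFrame.difference()` reports with its default settings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "max-diff-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Hypothetical per-group WER values, for illustration only\n",
    "toy_group_wer = pd.Series({'group_a': 0.20, 'group_b': 0.26, 'group_c': 0.23})\n",
    "\n",
    "# Maximal between-group difference: largest group value minus smallest group value\n",
    "toy_max_difference = toy_group_wer.max() - toy_group_wer.min()\n",
    "print(round(toy_max_difference, 2))"
   ]
  },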
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "advised-modem",
   "metadata": {},
   "outputs": [],
   "source": [
    "fairness_metrics = {\n",
    "    \"count\": count,\n",
    "    \"word_error_rate\": word_error_rate,\n",
    "    \"word_error_rate_abs_disparity\": wer_abs_disparity,\n",
    "    \"word_error_rate_rel_disparity\": wer_rel_disparity\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "african-personal",
   "metadata": {},
   "source": [
    "## 3.) Compare quantified harms across different groups"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "editorial-fraction",
   "metadata": {},
   "source": [
    "In the final part of our fairness assessment, we use the `MetricFrame` object in the `fairlearn` package to compare our system's performance across our `sensitive features`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "focal-count",
   "metadata": {},
   "source": [
    "To instantiate a `MetricFrame`, we pass in four parameters:\n",
    "- `metrics`: The `fairness_metrics` to evaluate each group on.\n",
    "- `y_true`: The `ground_truth_text`\n",
    "- `y_pred`: The `predicted_text`\n",
    "- `sensitive_features`: Our groups for fairness assessment"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "shared-happiness",
   "metadata": {},
   "source": [
    "For our first analysis, we look at the system's performance with respect to `country`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "twelve-duration",
   "metadata": {},
   "outputs": [],
   "source": [
    "metricframe_country = MetricFrame(\n",
    "    metrics=fairness_metrics,\n",
    "    y_true=filtered_df.loc[:, \"ground_truth_text\"],\n",
    "    y_pred=filtered_df.loc[:, \"predicted_text\"],\n",
    "    sensitive_features=filtered_df.loc[:, sensitive_country]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5774b8d7",
   "metadata": {},
   "source": [
    "Using the `MetricFrame`, we can easily compute the `word_error_rate` differences between our three countries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "separate-contents",
   "metadata": {},
   "outputs": [],
   "source": [
    "display(metricframe_country.by_group[[\"count\", \"word_error_rate\"]])\n",
    "display(metricframe_country.difference())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "affecting-calculation",
   "metadata": {},
   "source": [
    "We see that the maximal `word_error_rate` difference (between `UK` and `Canada`) is 0.05. Since the `by_group` attribute of a `MetricFrame` is a pandas `DataFrame`, we can take advantage of the plotting capabilities of pandas to visualize the `word_error_rate` by `country`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "interior-southwest",
   "metadata": {},
   "outputs": [],
   "source": [
    "metricframe_country.by_group.sort_values(by=\"word_error_rate\", ascending=False).plot(kind=\"bar\", y=\"word_error_rate\", ylabel=\"Word Error Rate\", title=\"Word Error Rate by Country\", figsize=[12,8])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "infrared-courtesy",
   "metadata": {},
   "source": [
    "Next, let's explore how our system performs with respect to the `sex` of the speaker. As we did for `country`, we create another `MetricFrame`, this time passing in the `sex` column as our `sensitive_features`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "lesser-stretch",
   "metadata": {},
   "outputs": [],
   "source": [
    "metricframe_sex = MetricFrame(\n",
    "    metrics=fairness_metrics,\n",
    "    y_true=filtered_df.loc[:, \"ground_truth_text\"],\n",
    "    y_pred=filtered_df.loc[:, \"predicted_text\"],\n",
    "    sensitive_features=filtered_df.loc[:, \"sex\"]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "finite-custody",
   "metadata": {},
   "outputs": [],
   "source": [
    "display(metricframe_sex.by_group[[\"count\", \"word_error_rate\"]])\n",
    "display(metricframe_sex.difference())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "convertible-feeling",
   "metadata": {},
   "source": [
    "In our `sex`-based analysis, we see there is a `0.06` difference in the *WER* between `male` and `female` speakers. If we added uncertainty quantification, such as *confidence intervals*, to our analysis, we could perform statistical tests to determine if the difference is statistically significant or not."
   ]
  },
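  {
   "cell_type": "markdown",
   "id": "bootstrap-text",
   "metadata": {},
   "source": [
    "One possible way to add such uncertainty quantification is a percentile bootstrap. The next cell is only a sketch under stated assumptions: it works on per-speaker WER values and compares group means, which differs from the corpus-level WER used above, and it runs on synthetic numbers rather than on this dataset. The function name `bootstrap_difference_ci` and its parameters are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bootstrap-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def bootstrap_difference_ci(errors_a, errors_b, n_resamples=2000, alpha=0.05, seed=0):\n",
    "    # Percentile bootstrap interval for mean(errors_a) - mean(errors_b)\n",
    "    rng = np.random.default_rng(seed)\n",
    "    errors_a = np.asarray(errors_a, dtype=float)\n",
    "    errors_b = np.asarray(errors_b, dtype=float)\n",
    "    differences = np.empty(n_resamples)\n",
    "    for i in range(n_resamples):\n",
    "        # Resample each group with replacement, keeping its original size\n",
    "        sample_a = rng.choice(errors_a, size=errors_a.size, replace=True)\n",
    "        sample_b = rng.choice(errors_b, size=errors_b.size, replace=True)\n",
    "        differences[i] = sample_a.mean() - sample_b.mean()\n",
    "    lower, upper = np.quantile(differences, [alpha / 2, 1 - alpha / 2])\n",
    "    return lower, upper\n",
    "\n",
    "\n",
    "# Synthetic per-speaker WER values, for illustration only\n",
    "toy_rng = np.random.default_rng(42)\n",
    "toy_wer_a = toy_rng.uniform(0.10, 0.40, size=200)\n",
    "toy_wer_b = toy_rng.uniform(0.05, 0.35, size=200)\n",
    "\n",
    "toy_lower, toy_upper = bootstrap_difference_ci(toy_wer_a, toy_wer_b)\n",
    "# An interval that excludes 0 suggests the two group means differ\n",
    "print(toy_lower, toy_upper)"
   ]
  },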
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "magnetic-attribute",
   "metadata": {},
   "outputs": [],
   "source": [
    "metricframe_sex.by_group.sort_values(by=\"word_error_rate\", ascending=False).plot(kind=\"bar\", y=\"word_error_rate\", ylabel=\"Word Error Rate\", title=\"Word Error Rate by Sex\", figsize=[12,8])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "departmental-stanley",
   "metadata": {},
   "source": [
    "### Intersectional Analysis"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "great-infection",
   "metadata": {},
   "source": [
    "One key aspect to remember when performing a fairness analysis is to explore the intersection of different groups. For this final analysis, we look at groups at the intersection of `country` and `sex`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6c98276",
   "metadata": {},
   "source": [
    "In particular, we are interested in seeing the `word_error_rate difference` by `sex` for each `country`. That is, we want to compare the `WER difference` between `Canada male` and `Canada female` to the `WER difference` between `males` and `females` of the other two countries."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6318ee15",
   "metadata": {},
   "source": [
    "When we instantiate our `MetricFrame` this time, we pass in the `country` column as `control_features`. Now when we call the `difference` method of our `MetricFrame`, it computes the `WER difference` between `male` and `female` speakers within each country."
   ]
  },
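  {
   "cell_type": "markdown",
   "id": "within-country-text",
   "metadata": {},
   "source": [
    "To show roughly what a difference computed within each control group looks like, the next cell reproduces the idea with plain pandas on a toy `MultiIndex` series. The country labels and WER values are hypothetical placeholders. The sketch takes the largest minus the smallest WER across `sex` inside each country, which is the kind of per-country gap described above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "within-country-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Hypothetical WER values per (country, sex) group, for illustration only\n",
    "toy_by_group = pd.Series(\n",
    "    [0.20, 0.25, 0.30, 0.22],\n",
    "    index=pd.MultiIndex.from_tuples(\n",
    "        [\n",
    "            ('country_a', 'female'),\n",
    "            ('country_a', 'male'),\n",
    "            ('country_b', 'female'),\n",
    "            ('country_b', 'male'),\n",
    "        ],\n",
    "        names=['country', 'sex'],\n",
    "    ),\n",
    "    name='word_error_rate',\n",
    ")\n",
    "\n",
    "# Within each country, take the largest minus the smallest WER across sexes\n",
    "toy_within_country_difference = toy_by_group.groupby(level='country').agg(\n",
    "    lambda values: values.max() - values.min()\n",
    ")\n",
    "print(toy_within_country_difference)"
   ]
  },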
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sophisticated-nashville",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make country a control feature\n",
    "metricframe_country_sex_control = MetricFrame(\n",
    "    metrics=fairness_metrics,\n",
    "    y_true=filtered_df.loc[:, \"ground_truth_text\"],\n",
    "    y_pred=filtered_df.loc[:, \"predicted_text\"],\n",
    "    sensitive_features=filtered_df.loc[:, \"sex\"],\n",
    "    control_features=filtered_df.loc[:, \"country\"]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2dc6593",
   "metadata": {},
   "outputs": [],
   "source": [
    "display(metricframe_country_sex_control.difference())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45c20c14",
   "metadata": {},
   "source": [
    "If we access the `by_group` attribute, we get a `DataFrame` with a `MultiIndex` showing the `count` and `word_error_rate` for each intersectional group."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9dd6844f",
   "metadata": {},
   "outputs": [],
   "source": [
    "display(metricframe_country_sex_control.by_group[[\"count\", \"word_error_rate\"]])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b22aac42",
   "metadata": {},
   "source": [
    "Now, let's explore our `word_error_rate` disparity by `sex` within each country. Plotting the absolute `word_error_rate` for each intersectional group shows us that the disparity between `UK male` and `UK female` is noticeably larger than in the other countries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6079cb5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "group_metrics = metricframe_country_sex_control.by_group[[\"count\", \"word_error_rate\"]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d0ffe3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "group_metrics[\"word_error_rate\"].unstack(level=-1).plot(\n",
    "    kind=\"bar\",\n",
    "    ylabel=\"Word Error Rate\",\n",
    "    title=\"Word Error Rate by Country and Sex\")\n",
    "plt.legend(bbox_to_anchor=(1.3,1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31a5a388",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_controlled_features(multiindexframe, title, xaxis, yaxis, order):\n",
    "    \"\"\"\n",
    "    Helper function to plot the male and female word error rate for each country,\n",
    "    joined by a dashed vertical line that shows the gap between the two.\n",
    "    \"\"\"\n",
    "    plt.figure(figsize=[12,8])\n",
    "    disagg_metrics = multiindexframe[\"word_error_rate\"].unstack(level=0).loc[:, order].to_dict()\n",
    "    male_scatter = []\n",
    "    female_scatter = []\n",
    "    # Materialize the keys as a list so they can be used as categorical x values\n",
    "    countries = list(disagg_metrics.keys())\n",
    "    for country, sex_wer in disagg_metrics.items():\n",
    "        male_point, female_point = sex_wer.get(\"male\"), sex_wer.get(\"female\")\n",
    "        plt.vlines(country, female_point, male_point, linestyles=\"dashed\", alpha=0.45)\n",
    "        # Collect the points in the same order as the countries on the x-axis\n",
    "        male_scatter.append(male_point)\n",
    "        female_scatter.append(female_point)\n",
    "    plt.scatter(countries, male_scatter, marker=\"^\", color=\"b\", label=\"Male\")\n",
    "    plt.scatter(countries, female_scatter, marker=\"s\", color=\"r\", label=\"Female\")\n",
    "    plt.title(title)\n",
    "    plt.legend(bbox_to_anchor=(1,1))\n",
    "    plt.xlabel(xaxis)\n",
    "    plt.ylabel(yaxis)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "977c6d5d",
   "metadata": {},
   "source": [
    "We can also visualize the disparity by `sex` within each `country`. From these plots, we see the difference between `UK male` and `UK female` is ~0.09. This is larger than the disparity between `US male` and `US female` (0.06) and the disparity between `Canada male` and `Canada female` (> 0.01)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47265c44",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_controlled_features(group_metrics,\n",
    "                         \"Word Error Rate by Country and Sex\",\n",
    "                         \"Country\",\n",
    "                         \"Word Error Rate\",\n",
    "                         order=[\"uk\", \"usa\", \"canada\"])\n",
    "\n",
    "metricframe_country_sex_control.difference().sort_values(by=\"word_error_rate\", ascending=False).plot(\n",
    "    kind=\"bar\",\n",
    "    y=\"word_error_rate\",\n",
    "    title=\"Word Error Rate Difference between Sexes by Country\",\n",
    "    figsize=[12,8])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "waiting-writer",
   "metadata": {},
   "source": [
    "# Conclusion\n",
    "\n",
    "With this fairness assessment, we explored how `country` and `sex` affect the quality of a speech-to-text transcription in three English-speaking countries. Through an intersectional analysis, we found a higher disparity in the *quality-of-service* between `UK male` and `UK female` compared to males and females of other countries."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
